
# Load raw ephys data
#
# Reads one ";"-delimited export and normalises the columns used downstream.
#
# Args:
#   file.name:       name of the file inside `raw.data.folder` (e.g. "a.csv").
#   raw.data.folder: folder prefix; it is pasted directly in front of
#                    `file.name`, so it must end with a path separator.
#
# Returns the table read by read_delim() with
#   * `File Name` renamed to `FileID` (only when that column exists),
#   * `Bin Left` / `Bin Right` renamed to `Time.start` / `Time.stop`,
#   * a `File.name` column holding `file.name` minus a trailing ".csv".
# On error the condition is logged and whatever had been built so far is
# returned (an empty data.frame if reading itself failed).
load_raw <- function(file.name, raw.data.folder = "./Data/") {
  # Initialised before tryCatch so the error handler can always return it,
  # even when the very first statement fails.
  Temp <- data.frame()

  tryCatch({
    flog.info(paste0("Processing file: ", file.name), name = log.name)

    Temp <- read_delim(file = paste0(raw.data.folder, file.name), delim = ";")

    if ("File Name" %in% colnames(Temp)) {
      Temp <- Temp %>% rename(FileID = `File Name`)
    }

    Temp <- Temp %>%
      rename(Time.start = `Bin Left`, Time.stop = `Bin Right`) %>%
      # Strip only a literal, trailing ".csv". The former unanchored regex
      # ".csv" also removed any character followed by "csv" anywhere in
      # the name (e.g. "rat1csv_a.csv" -> "rat_a").
      mutate(File.name = sub(pattern = "\\.csv$", replacement = "", file.name))

    Temp
  },
  error = function(c) {
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    Temp
  }) # end try catch
}


# Get session timings -----------------------------------------------------
# Build a table of session start/stop times from the marker columns.
#
# `raw.data` needs a `Time.start` column plus 0/1 marker columns whose names
# start with Baseline, Vehicle, BI_cmp, Saline or CNO and contain "_onset" or
# "_offset". The result has one row per session name with its `Onset` and
# `Offset` times. Any failure, including the Onset < Offset sanity check,
# is logged and the current value of Sessions.timings is returned instead.
get_session <- function(raw.data) {
  tryCatch({
    flog.info("Attempting to get sessions...", name = log.name)

    Sessions.timings <- data.frame()

    # Time column plus every session marker column.
    marker.cols <- select(raw.data,
                          Time.start,
                          starts_with("Baseline"),
                          starts_with("Vehicle"),
                          starts_with("BI_cmp"),
                          starts_with("Saline"),
                          starts_with("CNO"))

    # Only rows on which at least one marker fires.
    marker.rows <- filter_at(marker.cols,
                             vars(-starts_with("Time")),
                             any_vars(. == 1))

    # Long format, one row per (time, marker) event.
    events <- marker.rows %>%
      gather(key = "Session", value = "val", -Time.start) %>%
      filter(val == 1)

    # Split "<session>_onset" / "<session>_offset" into name + event type.
    events <- events %>%
      mutate(On.off = if_else(str_detect(string = Session, pattern = "_onset"),
                              "Onset", "Offset"),
             Session = str_replace_all(string = Session,
                                       pattern = "_onset|_offset",
                                       replacement = ""))

    Sessions.timings <- events %>%
      select(-val) %>%
      spread(key = On.off, value = Time.start)

    stopifnot(Sessions.timings$Onset < Sessions.timings$Offset)

    Sessions.timings
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Sessions.timings)
  }) # end try catch
}

# Separate session --------------------------------------------------------

# Poisson z-score helper: centre `x` on its mean and divide by sqrt(mean),
# i.e. use the mean as the variance, as a Poisson model assumes.
# NAs are ignored when computing the mean and remain NA in the result.
poisson.zscore <- function(x) {
  lambda <- mean(x, na.rm = TRUE)
  (x - lambda) / sqrt(lambda)
}

# Cut one session out of the full recording and prepare its spike data.
#
# Args:
#   session:          session name, matched against Session.timing$Session.
#   Session.timing:   table with Session / Onset / Offset columns
#                     (e.g. the output of get_session()).
#   Data:             full recording with Time.start, marker columns and
#                     "SPK*" spike columns.
#   min.spikes:       units whose summed spikes in the session do not exceed
#                     this are dropped as artifacts.
#   z.score.session:  normalise every unit within the session?
#   smooth.spikes:    smooth every unit with smth.gaussian()?
#   smooth.window:    `window` argument passed to smth.gaussian().
#   spike.multiplier: factor all spike columns are multiplied by.
#   normalize.spikes.method: "zscore" (scale()) or "poisson.zscore"
#                     (poisson.zscore()); only used when z.score.session.
#   expected.trials:  trial count of a complete session; any other count is
#                     logged as a warning.
#
# Processing order: multiply -> (optional) smooth -> (optional) normalise.
# Returns the session rows between the first light onset and the last ITI
# offset, with SessionID and TrialCnt added and the session marker columns
# and Start/Stop removed. On error the condition is logged and the partially
# processed data (an empty data.frame if nothing was extracted) is returned.
seperate_session <- function(session, Session.timing, Data, min.spikes = 10, z.score.session = FALSE,
                             smooth.spikes = FALSE, smooth.window = 1,
                             spike.multiplier = 1, normalize.spikes.method = "zscore",
                             expected.trials = 60) {
  # Defined up front so the error handler always has something to return;
  # before, an early failure made the handler itself fail on a missing object.
  Temp.day <- data.frame()

  tryCatch({
    flog.info(paste0("Filtering session: ", session), name = log.name)

    flog.debug("Pulling start-stop idx", name = log.name)
    session.row <- Session.timing %>% filter(Session == session)
    start.idx <- session.row %>% pull(Onset)
    stop.idx <- session.row %>% pull(Offset)

    # Exactly one matching session is required; an unknown session would
    # otherwise give zero-length bounds that pass the comparison vacuously.
    stopifnot(length(start.idx) == 1, length(stop.idx) == 1,
              start.idx < stop.idx)

    flog.debug("Filtering session", name = log.name)
    Temp.day <- Data %>%
      filter(between(Time.start, start.idx, stop.idx)) %>%
      select(-starts_with("Baseline"),
             -starts_with("Vehicle"),
             -starts_with("BI_cmp"),
             -starts_with("Saline"),
             -starts_with("CNO")) %>%
      mutate(SessionID = session)

    flog.debug("Removing last trial", name = log.name)
    # Keep first light onset .. last ITI offset (the time CageViewer was ON).
    # NOTE: this way the last trial is removed!
    Last.ITI.idx <- Temp.day %>% filter(ITI_period_offset == 1) %>%
      pull(Time.start) %>% max()
    First.light.idx <- Temp.day %>% filter(Light_onset_corrected == 1) %>%
      pull(Time.start) %>% min()

    Temp.day <- Temp.day %>%
      filter(between(Time.start, First.light.idx, Last.ITI.idx))

    # Missing markers leave an empty window; fail here with a clear message.
    stopifnot(nrow(Temp.day) > 0)

    flog.debug("Correcting ITI", name = log.name)
    # NOTE(review): presumably clears an ITI offset flagged on the first row
    # (which belongs to the trial before the window) -- confirm.
    Temp.day$ITI_period_offset[1] <- 0

    flog.debug("Adding Trial count", name = log.name)
    # Trial counter: increments at every light onset. Vectorised replacement
    # for the former row-by-row loop; NA markers count as "no onset".
    light.onset <- Temp.day$Light_onset_corrected
    is.onset <- !is.na(light.onset) & light.onset > 0
    # as.numeric keeps the column double, exactly as the old loop produced
    # (add_period() mixes it with TrialCnt - 1 in if_else()).
    Temp.day$TrialCnt <- as.numeric(cumsum(is.onset))
    TrialCnt <- sum(is.onset)

    if (TrialCnt != expected.trials) {
      flog.info(paste0("Trial Count: ", TrialCnt), name = log.name)
      flog.warn(paste0("Trial count not full: ", TrialCnt), name = log.name)
    }

    flog.debug("Remove artifact spikes", name = log.name)
    # Drop units with too few spikes in this session.
    Spikes <- Temp.day %>% select(starts_with("SPK")) %>%
      select_if(colSums(., na.rm = TRUE) > min.spikes)

    stopifnot(ncol(Spikes) > 0)

    Spikes <- Spikes %>%
      mutate_all(function(spk) spk * spike.multiplier)

    if (smooth.spikes) {
      flog.info(paste0("Smoothing spikes with window: ", smooth.window), name = log.name)
      Spikes <- Spikes %>%
        mutate_all(function(spk) smth.gaussian(spk, window = smooth.window, tails = TRUE))
    }

    # session-wise z-score
    if (z.score.session) {
      flog.info("Doing session-wise z-score", name = log.name)

      if (normalize.spikes.method == "zscore") {
        flog.info("Using standard z.score", name = log.name)
        # scale() returns a one-column matrix; flatten it so every spike
        # column stays a plain numeric vector.
        Spikes <- Spikes %>% mutate_all(function(spk) as.numeric(scale(spk)))
      } else if (normalize.spikes.method == "poisson.zscore") {
        flog.info("Using poisson approximation for z.score", name = log.name)
        Spikes <- Spikes %>% mutate_all(poisson.zscore)
      } else {
        flog.warn(paste0("Unknown normalize.spikes.method '",
                         normalize.spikes.method,
                         "': spikes left unnormalised"), name = log.name)
      }
    }

    flog.debug("Final merge", name = log.name)
    Temp.day <- bind_cols(Temp.day %>% select(-starts_with("SPK")), Spikes) %>%
      select(-Start, -Stop)

    Temp.day
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    Temp.day
  }) # end try catch
}

# Add Periods -------------------------------------------------------------
# Label each row of `Data` with the behavioural period it falls in.
#
# `Data` must contain Time.start, TrialCnt, FileID, SessionID and some of the
# 0/1 marker columns listed in `Periods` below. Onset/offset markers are
# paired per (period, trial), and every row with Onset <= Time.start <= Offset
# receives that period's name in a new `Period` column ("None" elsewhere).
#
# Side effects when a session has repeated markers or unpaired
# onsets/offsets: the period table is written as CSV to the global
# `output.folder` and a diagnostic plot is saved to the global `plots.folder`.
# On error the condition is logged and `Data` is returned as labelled so far.
add_period <- function(Data){
  tryCatch({
    flog.info("Adding periods", name = log.name)

    Periods <- c("Correct_Go_Sound_onset", "Correct_Go_Sound_offset",
                 "Correct_NG_Sound_onset", "Correct_NG_Sound_offset",
                 "FalseAlarm_Sound_onset", "FalseAlarm_Sound_offset",
                 "Omited_Go_Sound_onset", "Omited_Go_Sound_offset",
                 "Reward_period_onset", "Reward_period_offset",
                 "ITI_period_onset", "ITI_period_offset",
                 "Precue_period_onset", "Precue_period_offset")

    Data$Period <- "None"

    flog.debug("Getting Period Timings", name = log.name)

    Period.timings.raw <- Data %>%
      select(Time.start, TrialCnt, one_of(Periods)) %>%
      # Only the marker columns decide whether a row holds an event.
      # starts_with() matches a literal prefix, so the former
      # starts_with("Time|TrialCnt") excluded nothing.
      filter_at(vars(-Time.start, -TrialCnt), any_vars(. == 1)) %>%
      gather(key = "Period", value = "val", -Time.start, -TrialCnt) %>%
      filter(val == 1) %>%
      group_by(Period, TrialCnt) %>%
      # Cnt > 1 means the same marker fired more than once in one trial.
      mutate(Cnt = seq_len(n())) %>%
      ungroup() %>%
      # NOTE(review): ITI offsets are credited to the previous trial,
      # presumably because the next light onset has already advanced
      # TrialCnt when the ITI ends -- confirm.
      mutate(TrialCnt = if_else(Period == "ITI_period_offset", TrialCnt - 1, TrialCnt)) %>%
      mutate(On.off = if_else(str_detect(string = Period, pattern = "_onset"), "Onset", "Offset"),
             Period = str_replace_all(string = Period, pattern = "_onset|_offset", replacement = "")) %>%
      select(-val) %>%
      spread(key = On.off, value = Time.start) %>%
      mutate(Delta = Offset - Onset) %>%
      arrange(desc(Period))

    # Detect unpaired onsets/offsets BEFORE dropping them; checking after
    # complete.cases() (as before) could never find any.
    has.missing <- !all(complete.cases(Period.timings.raw))
    Period.timings <- Period.timings.raw %>% filter(complete.cases(.))

    if (any(Period.timings$Cnt != 1) || has.missing) {
      flog.info(paste0("Session having wrong periods/missing values: ",
                       unique(Data$FileID), " in ",
                       unique(Data$SessionID)), name = log.name)
      # Unfiltered table, so incomplete (NA) rows show up in the dump too.
      write_csv(Period.timings.raw,
                path = paste0(output.folder, unique(Data$FileID), unique(Data$SessionID), ".csv"))

      Plot <- Data %>%
        select(Time.start, TrialCnt,
               one_of(c(Periods, "Light_onset_corrected"))) %>%
        gather(key = "Period", value = "val", -Time.start, -TrialCnt) %>%
        filter(val == 1) %>%
        mutate(On.off = if_else(str_detect(string = Period, pattern = "_onset"), "Onset", "Offset"),
               Period = str_replace_all(string = Period, pattern = "_onset|_offset", replacement = ""),
               On.off = factor(On.off))

      error.plot <- ggplot() +
        geom_vline(data = Plot %>% filter(Period == "Light_corrected"),
                   aes(xintercept = Time.start), alpha = 0.5) +
        geom_point(data = Plot %>% filter(Period != "Light_corrected"),
                   aes(x = Time.start, y = Period, color = On.off)) +
        theme_bw() +
        scale_color_manual(values = c("red", "blue")) +
        geom_text(data = Plot %>% filter(Period == "Light_corrected"),
                  aes(x = Time.start, y = 5.5, label = TrialCnt), alpha = 0.5, size = 3)

      # Pass the plot explicitly instead of relying on last_plot().
      ggsave(filename = paste0(plots.folder, "Error_",
                               unique(Data$FileID),
                               unique(Data$SessionID), ".png"),
             plot = error.plot,
             width = 15, height = 5, dpi = 300, type = "cairo-png")
    }

    if (any(Period.timings$Cnt != 1)) {
      flog.info(paste0("Session having wrong periods: ",
                       unique(Data$FileID), " in ",
                       unique(Data$SessionID)), name = log.name)
      wrong <- Period.timings %>%
        filter(Cnt != 1)

      # Log the table as printed text rather than as raw column vectors.
      flog.info(paste(capture.output(print(as.data.frame(wrong))), collapse = "\n"),
                name = log.name)

      # Keep the highest-Cnt row per (Period, trial). A grouped slice()
      # returns rows in group order, so restore desc(Period): that order sets
      # the overwrite priority in the labelling loop below.
      Period.timings <- Period.timings %>%
        group_by(Period, TrialCnt) %>%
        slice(which.max(Cnt)) %>%
        ungroup() %>%
        arrange(desc(Period))
    }

    test_n <- Period.timings %>%
      group_by(Period, TrialCnt) %>%
      summarise(n = n())

    if (any(test_n$n != 1)) {
      flog.info(paste0("Session having wrong periods: ",
                       unique(Data$FileID), " in ",
                       unique(Data$SessionID)), name = log.name)
      # Formerly a leftover interactive browser() call; log the offending
      # (Period, trial) pairs and continue instead.
      duplicated.pairs <- test_n %>% ungroup() %>% filter(n != 1)
      flog.warn(paste(capture.output(print(as.data.frame(duplicated.pairs))), collapse = "\n"),
                name = log.name)
    }

    flog.debug("Adding periods", name = log.name)
    # Rows are visited in desc(Period) order, so where periods overlap the
    # alphabetically first period name is written last and wins.
    for (i in seq_len(nrow(Period.timings))) {
      period.start <- Period.timings[[i, "Onset"]]
      period.stop <- Period.timings[[i, "Offset"]]
      in.period <- Data$Time.start >= period.start & Data$Time.start <= period.stop
      Data$Period[in.period] <- Period.timings[[i, "Period"]]
    }

    flog.debug("Periods done", name = log.name)

    Data
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Data)
  }) # end try catch
}


# Add brain area and Treatment ID ----------------------------------------------
# Adds brain-area and treatment labels, then GATHERS the spike columns.
#
# New/changed columns:
#   Area.brain  - "STN" when File.name contains "STN", otherwise "SN".
#   SessionID   - the substring "BI_cmp_11" is rewritten to "BI_cmp_1".
#   TreatmentID - first of Baseline / Vehicle / BI_cmp / CNO / Saline found
#                 in SessionID, NA when none matches.
# Every column whose name contains "SPK" is then gathered into Unit/z.score.
# On error the condition is logged and the input Data is returned unchanged.
add_area_id <- function(Data){
  # SessionID -> treatment label; earlier patterns take precedence.
  treatment_of <- function(session.id) {
    case_when(grepl(pattern = "Baseline", session.id) ~ "Baseline",
              grepl(pattern = "Vehicle", session.id) ~ "Vehicle",
              grepl(pattern = "BI_cmp", session.id) ~ "BI_cmp",
              grepl(pattern = "CNO", session.id) ~ "CNO",
              grepl(pattern = "Saline", session.id) ~ "Saline")
  }

  tryCatch({
    flog.info("Adding brain and Treatment IDs", name = log.name)

    labelled <- Data %>%
      mutate(Area.brain = if_else(str_detect(string = File.name, pattern = "STN"), "STN", "SN"),
             SessionID = str_replace_all(string = SessionID,
                                         pattern = "BI_cmp_11",
                                         replacement = "BI_cmp_1"),
             TreatmentID = treatment_of(SessionID))

    gather(labelled, key = Unit, value = z.score, contains("SPK"))
  },
  error = function(c){
    flog.error("ERROR!", name = log.name)
    flog.error(c, name = log.name)
    return(Data)
  }) # end try catch
}